open Base

(* A rose tree. Internal nodes ([Arm]) carry a payload of type ['a] and an
   ordered, possibly empty, list of children; leaves carry a payload of type
   ['l]. *)
type ('a, 'l) tree =
  | Arm of 'a * ('a, 'l) tree list
  | Leaf of 'l

(* Alias so the tree type can also be referred to by the conventional name [t]. *)
type ('a, 'l) t = ('a, 'l) tree

(* Minimal interface for a payload type that can be rendered as a string and
   compared for equality. [Make_show_eq] takes one of these for the arm
   payload and one for the leaf payload. *)
module type SHOW_EQ = sig
  type t
  val to_string : t -> string
  val equal : t -> t -> bool
end

module Make_show_eq(A: SHOW_EQ)(L: SHOW_EQ) = struct
  (* Render a tree as text. An arm prints as "tag[child1; child2; ...]"
     (so a childless arm prints as "tag[]"), and a leaf prints as its payload
     alone. *)
  let rec to_string tree =
    match tree with
    | Leaf leaf -> L.to_string leaf
    | Arm (tag, children) ->
      let inner = String.concat ~sep:"; " (List.map children ~f:to_string) in
      Printf.sprintf "%s[%s]" (A.to_string tag) inner

  (* Structural equality: [A.equal] on arm payloads, [L.equal] on leaf
     payloads, and children compared pairwise in order. The single-child case
     is subsumed by the general case. It appears to be split out so that a
     long chain of one-child arms recurses in tail position (the right operand
     of [&&]) rather than through [List.equal]. *)
  let rec equal lhs rhs =
    match lhs, rhs with
    | Leaf l1, Leaf l2 -> L.equal l1 l2
    | Arm (x1, [ c1 ]), Arm (x2, [ c2 ]) -> A.equal x1 x2 && equal c1 c2
    | Arm (x1, cs1), Arm (x2, cs2) -> A.equal x1 x2 && List.equal cs1 cs2 ~equal
    | Leaf _, Arm _ | Arm _, Leaf _ -> false
end


(*** zipper ***)

(* A path through a tree as a list of child indices, outermost first.
   NOTE(review): not referenced elsewhere in this section; presumably matches
   the 0-based index lists taken by [Zipper.nav] ([~prev], [~targ]) — confirm. *)
type path = int list

module Zipper = struct
  (* A zipper over [('a, 'l) tree]. [foc] is the subtree currently in focus.
     [ctx] holds one frame per enclosing arm, innermost frame first; an empty
     [ctx] means the focus is the root. *)
  type ('a, 'l) t =
    { ctx : ('a, 'l) arm_ctx list
    ; foc : ('a, 'l) tree }

  (* One frame of context. It records:
     - the enclosing arm's payload [tag];
     - the siblings to the left of the focus, in reverse order (nearest
       first);
     - the siblings to the right of the focus, in their original order. *)
  and ('a, 'l) arm_ctx =
    { tag : 'a
    ; left_rev : ('a, 'l) tree list
    ; right : ('a, 'l) tree list }

  (* Reasons a navigation step can fail. *)
  module Error = struct
    type t = Bad_index | Not_arm | No_parent

    (* Human-readable message for each error. *)
    let to_string = function
      | No_parent -> "trying to navigate above root"
      | Not_arm -> "cannot navigate to node"
      | Bad_index -> "node index out of range"

    (* Unwrap [Ok x] to [x]; on [Error e], raise [Failure] carrying
       [to_string e]. *)
    let or_fail = function
      | Error e -> failwith (to_string e)
      | Ok x -> x
    [@@ocaml.inline always]
  end

  (*** basics ***)

  (* Start a zipper focused on the root of [foc]. *)
  let of_tree foc = { foc ; ctx = [] } [@@ocaml.inline always]

  (* The subtree currently in focus. *)
  let get z = z.foc [@@ocaml.inline always]

  (* Replace the focused subtree, keeping the surrounding context. *)
  let set foc z = { ctx = z.ctx ; foc } [@@ocaml.inline always]

  (*** movement ***)

  (* Move focus to the parent arm, rebuilding it from the innermost context
     frame. Fails with [No_parent] when already at the root. *)
  let up z =
    match z.ctx with
    | [] -> Error Error.No_parent
    | frame :: outer ->
      let children = List.rev_append frame.left_rev (z.foc :: frame.right) in
      Ok { ctx = outer ; foc = Arm (frame.tag, children) }

  (* Move focus to the [i]-th (0-based) child of the focused arm.
     Fails with [Not_arm] if the focus is a leaf. Fails with [Bad_index] if
     [i] is negative or not less than the number of children. *)
  let down i z =
    match z.foc with
    | Leaf _ -> Error Error.Not_arm
    | Arm (tag, children) ->
      (* Step through [remaining] until [k] reaches 0, pushing each skipped
         sibling onto [left_rev] so it ends up nearest-first. *)
      let rec walk k left_rev remaining =
        match remaining with
        | [] -> Error Error.Bad_index
        | child :: right ->
          if k = 0 then Ok { foc = child ; ctx = { tag ; left_rev ; right } :: z.ctx }
          else walk (k - 1) (child :: left_rev) right
      in
      if i < 0 then Error Error.Bad_index else walk i [] children

  (* Move focus [amt] siblings to the right under the same parent; a negative
     [amt] moves left. Fails with [No_parent] at the root. Fails with
     [Bad_index] if the move would run past either end of the sibling list. *)
  let side amt z =
    match z.ctx with
    | [] -> Error Error.No_parent
    | { tag ; left_rev ; right } :: outer ->
      (* Shift one sibling per step; the sign of [steps] gives the direction. *)
      let rec shift steps lefts rights foc =
        if steps = 0 then
          Ok { ctx = { tag ; left_rev = lefts ; right = rights } :: outer ; foc }
        else if steps < 0 then
          (match lefts with
           | [] -> Error Error.Bad_index
           | l :: ls -> shift (steps + 1) ls (foc :: rights) l)
        else
          (match rights with
           | [] -> Error Error.Bad_index
           | r :: rs -> shift (steps - 1) (foc :: lefts) rs r)
      in
      shift amt left_rev right z.foc

  (* Navigate from the node at path [prev] to the node at path [targ].
     - A path is a list of 0-based child indices from the root.
     - [prev] defaults to the root.
     - [z] is assumed to be focused at [prev]; this is not checked. *)
  let nav ?(prev=[]) ~targ z =
    let open Result.Monad_infix in
    (* Go up once per element of [ups], then down along the indices in
       [downs]. *)
    let climb_then_descend ups downs =
      List.fold_result ups ~init:z ~f:(fun acc _ -> up acc)
      >>= fun top -> List.fold_result downs ~init:top ~f:(fun acc i -> down i acc)
    in
    (* Drop the shared prefix of the two paths. If what is left is a single
       index on each side, the nodes are siblings and a sideways move
       suffices. Otherwise climb out of the rest of [prev] and descend along
       the rest of [targ]. *)
    let rec strip from_path to_path =
      match from_path, to_path with
      | [ i ], [ j ] -> side (j - i) z
      | i :: from_rest, j :: to_rest when Int.(i = j) -> strip from_rest to_rest
      | _, _ -> climb_then_descend from_path to_path
    in
    strip prev targ

  (* Climb all the way to the root and return the whole tree there. *)
  let rec to_tree z =
    match up z with
    | Error _ -> z.foc
    | Ok parent -> to_tree parent

  (*** "_exn" variants ***)

  (* Same as [up], [down], [side] and [nav], but raise [Failure] with the
     error's message instead of returning [Error]. *)
  let up_exn z = up z |> Error.or_fail
  let down_exn i z = down i z |> Error.or_fail
  let side_exn amt z = side amt z |> Error.or_fail
  let nav_exn ?prev ~targ z = nav ?prev ~targ z |> Error.or_fail
end
